
Conversation

shino16 (Collaborator) commented Sep 29, 2025

Fixes #2539, which in turn fixes #2527 and #2501. This makes PR #2538 obsolete.

As noted in #2527 (comment), when we pass a GraphModule to torch.compile it does not get lowered properly. As a workaround, this PR wraps the fallback GraphModule in a new nn.Module before compiling it. The workaround was suggested in #2527 (comment) (credit @mattteochen).
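
A minimal sketch of the idea (not the exact code in this PR; the class name and fallback_gm are illustrative):

import torch

class WrappedFallback(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__()
        self.gm = gm

    def forward(self, *args, **kwargs):
        return self.gm(*args, **kwargs)

# Passing the wrapper instead of the GraphModule itself lets torch.compile re-trace
# the forward and hand the resulting graph to Inductor for lowering.
# compiled = torch.compile(WrappedFallback(fallback_gm))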

Alternatively, we could call torch._inductor.compile_fx.compile_fx directly to compile the GraphModule, but since it returns a bare forward function rather than an nn.Module instance, the result cannot be registered as a submodule of the outer GraphModule.
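
For reference, a hedged sketch of that alternative; fallback_gm and example_inputs stand in for the real objects, and the call is shown commented out:

from torch._inductor.compile_fx import compile_fx

# compile_fx lowers the GraphModule through Inductor directly, but it returns a plain
# compiled callable rather than an nn.Module, so the result cannot serve as the target
# of a call_module node in the outer GraphModule.
# compiled_fn = compile_fx(fallback_gm, example_inputs)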

@shino16 force-pushed the wrap-inductor-submodule branch from f3311ea to ec67c18 on September 29, 2025 15:23
shino16 (Collaborator, Author) commented Sep 29, 2025

The test failures are coming from this:

import torch

class GraphModule(torch.nn.Module):
    def forward(self, y: "f32[2, 2]"):
        # No stacktrace found for following nodes
        _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None)

         # File: /opt/pytorch/lightning-thunder/thunder/tests/test_dynamo.py:280 in func, code: y = torch.sinc(y)
        y_1: "f32[2, 2]" = torch.sinc(y);  y = None

        # No stacktrace found for following nodes
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
        return y_1

model = GraphModule()
y = torch.randn(2, 2)
torch.compile(model)(y)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 338, in _fn
    assert guards.check(), (
           ^^^^^^^^^^^^^^
AssertionError: Global autocast state changed while dynamo tracing, please report a bug

@shino16 marked this pull request as draft on September 30, 2025 10:39
shino16 (Collaborator, Author) commented Sep 30, 2025

PyTorch checks that the global autocast state stays unchanged while Dynamo is tracing (ref). The torch.autocast() context manager satisfies this check by restoring the state in its cleanup (ref), but torch.amp.autocast_mode._enter_autocast has no such mechanism, so the global state is left changed.
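
For comparison, a minimal sketch (same shapes as the repro above) that uses the context manager instead of the raw _enter_autocast op; per the above, the state is restored on exit, so the guard check should pass:

import torch

class WithAutocastCM(torch.nn.Module):
    def forward(self, y):
        # The context manager restores the global autocast state in its __exit__,
        # so Dynamo's guard on global state is not tripped.
        with torch.autocast('cpu'):
            return torch.sinc(y)

torch.compile(WithAutocastCM())(torch.randn(2, 2))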

It may be better to rely on compile_fx instead of wrapping the GraphModule and tracing it again.

Successfully merging this pull request may close these issues.

ThunderFX's fallback is not using Inductor compilation
Activation checkpoint not working inside Inductor-compiled submodules